from mynumpy import *
from util import *

class estimator_iid:
    """Combine ``bins`` independent copies of a base estimator.

    The omega of this estimator is the row-wise concatenation of ``bins``
    base-estimator omegas, so ``omega_dim == base_est.omega_dim * bins``.
    ``logR`` is the log of the mean (over bins) of the per-bin ``R`` values,
    and ``sample_z`` picks one bin per column with probability proportional
    to that bin's ``R`` and returns that bin's sample.

    Parameters
    ----------
    base_est : object
        Estimator exposing ``label``, ``omega_dim``, ``w_dim``,
        ``logR(omega, w)`` and ``sample_z(omega, w)``.
    bins : int
        Number of copies to combine; ``omegas`` is split into this many
        equal row blocks.
    """

    def __init__(self, base_est, bins):
        self.base_est = base_est
        self.bins = bins
        self.label = 'iid' + str(bins) + '-' + self.base_est.label
        # One base omega block per bin, stacked along axis 0.
        self.omega_dim = base_est.omega_dim * bins
        self.w_dim = base_est.w_dim

    def sample_omega(self, num):
        """Return a ``(omega_dim, num)`` array drawn with ``rand``.

        NOTE(review): this draws all entries directly with ``rand``
        (presumably uniform on [0, 1), as numpy.random.rand) instead of
        delegating to ``base_est.sample_omega`` per bin. That is only
        equivalent if the base estimator's omega is drawn the same way;
        confirm against the base estimators in use.
        """
        return rand(self.omega_dim, num)

    def logRs(self, omegas, w):
        """Return per-bin log R values stacked as rows (one row per bin).

        ``omegas`` is split along axis 0 into ``bins`` equal blocks; each
        block is passed to ``base_est.logR`` together with ``w``.
        """
        blocks = np.split(omegas, self.bins)
        return np.vstack([self.base_est.logR(block, w) for block in blocks])

    def logR(self, omegas, w):
        """Return log(mean over bins of R), computed column by column.

        Evaluated as ``logsumexp(logRs - log(bins), axis=0)`` so the
        averaging happens in log space.
        """
        logRS = self.logRs(omegas, w)
        # Use -log(bins) rather than log(1/bins): with Python 2 integer
        # division 1/bins is 0 for bins > 1, which would yield log(0) = -inf.
        return logsumexp(logRS - log(self.bins), axis=0)

    def sample_zs(self, omegas, w):
        """Return one base-estimator sample per bin, stacked as rows.

        NOTE(review): assumes ``base_est.sample_z`` returns one value per
        column (1-D), so that row k of the result belongs to bin k.
        """
        blocks = np.split(omegas, self.bins)
        return np.vstack([self.base_est.sample_z(block, w) for block in blocks])

    def sample_z(self, omegas, w):
        """Sample z by choosing, per column, a bin in proportion to its R.

        For each column, bin k is selected with weight R_k / sum_j R_j.
        The value returned for that column is bin k's base-estimator sample.
        """
        logRS = self.logRs(omegas, w)
        # Normalise each column in log space before exponentiating to avoid
        # overflow/underflow of the raw R values.
        pi = exp(logRS - logsumexp(logRS, axis=0))
        # Presumably draws one row index per column from the column
        # distributions in pi -- helper lives in util; confirm there.
        i = sample_each_column(pi)
        zs = self.sample_zs(omegas, w)
        return one_from_each_col(zs, i)

    # NOTE(review): disabled code below references self.basedist and
    # dist_params, neither of which is defined in this class.
    #def a(self,z,w):
        #return self.basedist.pdf(z,*dist_params(w))